import pandas as pd
import numpy as np
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from sklearn.utils import resample
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats
import warnings
# Quiet UserWarning/FutureWarning noise, opt in to pandas' future
# no-silent-downcasting behaviour and disable chained-assignment warnings.
for _category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_category)
del _category
pd.set_option('future.no_silent_downcasting', True)
pd.set_option('mode.chained_assignment', None)

# Inputs: PolNLI test-set predictions, the per-dataset results matrix, and the
# four hardware timing benchmarks (read in this order).
test, results, mps, cuda, cpu, t4 = (
    pd.read_csv(path)
    for path in (
        './data/polnli_test_results.csv',
        './data/results_matrix.csv',
        './data/mps_timings.csv',
        './data/cuda_timings.csv',
        './data/cpu_timings.csv',
        './data/t4_timings.csv',
    )
)

# Functions for calculating metrics and bootstrapped standard errors
def metrics(df, preds, group_by=None):
    """
    Compute MCC, accuracy and macro-averaged F1 for each prediction column.

    Args:
        df (pd.DataFrame): Must hold the true labels in 'entailment' plus every
            column named in `preds` (and `group_by`, when grouping).
        preds (list): Names of the columns containing model predictions.
        group_by (str, optional): 'dataset' or 'task' to score each group
            separately. Any other value (including None) scores the whole
            frame at once. Defaults to None.

    Returns:
        pd.DataFrame: Columns 'MCC', 'Accuracy', 'F1'. Indexed by 'Column'
        when ungrouped, or by ['Column', group_by.capitalize()] when grouped.
    """
    truth = 'entailment'

    def _score(y_true, y_pred):
        return {
            'MCC': matthews_corrcoef(y_true, y_pred),
            'Accuracy': accuracy_score(y_true, y_pred),
            'F1': f1_score(y_true, y_pred, average='macro'),
        }

    grouped = group_by in ['dataset', 'task']
    rows = []
    for col in preds:
        if not grouped:
            row = _score(df[truth], df[col])
            row['Column'] = col
            rows.append(row)
            continue
        # Index level name is the capitalised grouping column, e.g. 'Task'.
        label = group_by.capitalize()
        for name, part in df.groupby(group_by):
            row = _score(part[truth], part[col])
            row['Column'] = col
            row[label] = name
            rows.append(row)

    table = pd.DataFrame(rows)
    index_cols = ['Column', group_by.capitalize()] if grouped else 'Column'
    return table.set_index(index_cols)

def bootstrapped_errors(y_true, y_pred, n_bootstrap=1000):
    """
    Calculate bootstrapped standard errors for MCC, Accuracy, and macro F1.

    Each replicate resamples the (y_true, y_pred) pairs with replacement using
    sklearn's ``resample`` and recomputes the three metrics. The standard error
    is ``np.std`` (population std, ddof=0) of the replicate scores. No
    ``random_state`` is passed to ``resample``, so the draws come from NumPy's
    global random state.

    Args:
        y_true (array-like): True labels.
        y_pred (array-like): Predicted labels, aligned with `y_true`.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.

    Returns:
        dict: {'MCC_SE', 'Accuracy_SE', 'F1_SE'} standard errors.
    """
    mcc_scores = []
    accuracy_scores = []
    f1_scores = []

    for _ in range(n_bootstrap):
        # Resample (true, pred) pairs together so they stay aligned.
        y_true_resampled, y_pred_resampled = resample(y_true, y_pred)

        mcc_scores.append(matthews_corrcoef(y_true_resampled, y_pred_resampled))
        accuracy_scores.append(accuracy_score(y_true_resampled, y_pred_resampled))
        # Macro averaging matches the F1 point estimate computed in `metrics`.
        # Otherwise F1_SE would describe a different statistic than the F1
        # value it is reported and bounded alongside.
        f1_scores.append(f1_score(y_true_resampled, y_pred_resampled, average='macro'))

    return {
        'MCC_SE': np.std(mcc_scores),
        'Accuracy_SE': np.std(accuracy_scores),
        'F1_SE': np.std(f1_scores),
    }

def metrics_with_errors(df, preds, n_bootstrap=1000, group_by=None, z=1.0):
    """
    Calculate metrics, bootstrapped standard errors and SE-based bounds.

    Args:
        df (pd.DataFrame): Must hold the true labels in 'entailment' plus every
            column named in `preds` (and `group_by`, when grouping).
        preds (list): List of column names containing model predictions.
        n_bootstrap (int, optional): Bootstrap replicates per model (or per
            model/group pair). Defaults to 1000.
        group_by (str, optional): 'dataset' or 'task' to compute results per
            group. Any other value (including None) gives ungrouped results.
            Defaults to None.
        z (float, optional): Multiplier applied to each standard error when
            forming the bounds, ``metric ± z * SE``. The default 1.0 gives
            ±1 SE bands, which are not 95% confidence intervals. Under a
            normal approximation, ``z=1.96`` gives approximately 95% intervals.

    Returns:
        pd.DataFrame: Indexed like `metrics` ('Column', or ['Column',
        group_by.capitalize()]). Contains 'MCC', 'Accuracy', 'F1', their
        '<metric>_SE' columns, and '<metric>_Lower' / '<metric>_Upper' bounds.
    """
    # Step 1: point estimates.
    metrics_df = metrics(df, preds, group_by=group_by)

    # Step 2: bootstrap standard errors. The loop order (model, then group) is
    # kept stable so the bootstrap consumes random draws in a fixed sequence.
    grouped = group_by in ['dataset', 'task']
    errors = []
    for col in preds:
        if grouped:
            for group_name, group in df.groupby(group_by):
                row = bootstrapped_errors(group['entailment'], group[col], n_bootstrap=n_bootstrap)
                row['Column'] = col
                row[group_by.capitalize()] = group_name
                errors.append(row)
        else:
            row = bootstrapped_errors(df['entailment'], df[col], n_bootstrap=n_bootstrap)
            row['Column'] = col
            errors.append(row)

    errors_df = pd.DataFrame(errors)
    index_cols = ['Column', group_by.capitalize()] if grouped else 'Column'
    errors_df = errors_df.set_index(index_cols)

    # Step 3: join the estimates and SEs on their shared index (inner join).
    combined_df = metrics_df.merge(errors_df, left_index=True, right_index=True)

    # Step 4: symmetric bounds, metric ± z * SE.
    for name in ('MCC', 'Accuracy', 'F1'):
        half_width = z * combined_df[f'{name}_SE']
        combined_df[f'{name}_Lower'] = combined_df[name] - half_width
        combined_df[f'{name}_Upper'] = combined_df[name] + half_width

    return combined_df

# Prediction columns in `test` scored for Table 8: base and large variants of
# the nli, debate and modern classifiers, then the llama, llama70b and sonnet
# prediction columns.
columns = [
    f'{size}_{family}'
    for family in ('nli', 'debate', 'modern')
    for size in ('base', 'large')
] + ['llama', 'llama70b', 'sonnet']

###########
## Table 8
###########
# Metrics with bootstrapped SEs (n_bootstrap left at its default) overall, per
# task and per dataset. Call order is kept so the bootstrap consumes random
# draws in the same sequence.
overall = metrics_with_errors(test, columns, group_by = None)
task = metrics_with_errors(test, columns, group_by = 'task')
dataset = metrics_with_errors(test, columns, group_by = 'dataset')

# Flatten the indexes into ordinary 'Model' / 'Task' columns so frames can be stacked.
overall = overall.reset_index().rename(columns={'Column': 'Model'})
overall['Task'] = 'overall'
task = task.reset_index().rename(columns={'Column': 'Model'})
dataset = dataset.reset_index().rename(columns={'Dataset': 'Task', 'Column': 'Model'})
dataset['Task'] = dataset['Task'].str.replace('mlburnham/', '')

task = pd.concat([task, overall])

# Display labels are assigned by position. They must line up with the order in
# which the pivot emits the raw task names (pivot sorts them).
# NOTE(review): brittle if the set or spelling of task names changes.
_TASK_LABELS = ['Event', 'Hatespeech', 'Overall', 'Stance', 'Topic']
_TASK_ORDER = ['Stance', 'Topic', 'Hatespeech', 'Event', 'Overall']


def _task_pivot(frame, value):
    """Pivot `value` into a Model x Task table with display labels and column order."""
    wide = frame[['Model', 'Task', value]].pivot(index=['Model'], columns='Task', values=value)
    wide.columns = _TASK_LABELS
    return wide[_TASK_ORDER]


pivot_F1 = _task_pivot(task, 'F1')
pivot_F1SE = _task_pivot(task, 'F1_SE')

# Interleave each task's F1 with its standard error: Stance, Stance_SE, Topic, ...
merged = pd.DataFrame(
    {
        key: source[col]
        for col in pivot_F1.columns
        for key, source in ((col, pivot_F1), (f"{col}_SE", pivot_F1SE))
    },
    index=pivot_F1.index,
)

# Express as percentages, 2 decimals.
merged = merged.mul(100).round(2)

# Rows follow the same model order as `columns`.
model_order = list(columns)
merged = merged.loc[model_order].reset_index()

output_file = './tables/table_8.txt'
pivot_string = merged.to_string()
with open(output_file, 'w') as out:
    out.write(pivot_string)

############
## Figure 7
############

# Map raw model / task identifiers to display names. DataFrame.replace matches
# whole cell values in every column of `results`.
results.replace({'base_nli': 'NLI Base', 
            'large_nli':'NLI Large', 
            'base_debate':'DEBATE Base',
            'llama':'Llama 3.1 8B',
            'llama70b':'Llama 3.3 70B',
            'large_debate':'DEBATE Large',
            'sonnet':'Claude 3.5',
            'base_modern': 'DEBATE Base (MB)',
            'large_modern': 'DEBATE Large (MB)',
            'event extraction': 'Events',
            'hatespeech and toxicity':'Hatespeech',
            'stance detection':'Stance',
            'topic classification':'Topic'
           }, inplace = True)

results.reset_index(drop = True, inplace = True)

metric = 'F1'
tasks = ['PoliStance_Affect', 'PoliStance_Affect_QT',
       'acled_event_entailment', 'argument_quality_ranking_entailment',
       'bill_summary_entailment', 'dehumanizing_hatespeech_entailment',
       'dem_rep_party_platform_topics', 'ibm_claimstance_entailment',
       'ibm_claimstance_topic_entailment', 'polistance_issue_tweets',
       'scad_event_entailment', 'targeted_hatespeech_entailment',
       'violent_hatespeech_entailment']

# Top-to-bottom order of models on the y axis.
modelorder = ["DEBATE Large", "DEBATE Large (MB)", "DEBATE Base", "DEBATE Base (MB)", 
              "Claude 3.5", "Llama 3.3 70B", "NLI Large", "Llama 3.1 8B", "NLI Base"]

data = results[results['Task'].isin(tasks)].copy()
# Ordered categorical so seaborn lays the boxes out in `modelorder`.
data['Model'] = pd.Categorical(data['Model'], categories=modelorder, ordered=True)
bar_color = sns.color_palette("colorblind")[0]

plt.figure(figsize=(14, 8))

# In seaborn's boxplot `fill` is a boolean (solid vs. line-art boxes). Passing
# a colour tuple there just meant fill=True and left the colour unused. The
# single box colour has to go through `color`.
sns.boxplot(
    x=metric,
    y='Model',
    data=data,
    color=bar_color,
)
sns.despine(top=True, right=True)

plt.xlabel(metric, fontweight = 'bold', size = 24)
plt.ylabel('')
# Ticks at 0.5-1.0 on the metric scale, labelled as percentages.
plt.xticks(ticks = [.5, .6, .7, .8, .9, 1], labels = ['50%', '60%', '70%', '80%', '90%', '100%'])
plt.yticks(fontsize = 24, fontweight = 'bold')
plt.grid(axis='x', which='both', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(r'./figures/figure_7.png', dpi = 300)

############
## Figure 8
############
# ignore_index gives a fresh RangeIndex. The row append below relies on that,
# since it uses len(timing) as the new label.
timing = pd.concat([mps, cuda, cpu, t4], ignore_index=True)

# Assign the result back rather than calling replace(inplace=True) on
# timing['Model']. The inplace-on-a-column form is chained assignment, which
# stops updating `timing` under pandas Copy-on-Write.
timing['Model'] = timing['Model'].replace({'Political_DEBATE_DeBERTa_base_v1.1': 'DEBATE Base',
       'Political_DEBATE_ModernBERT_base_v1.0': 'DEBATE Base (MB)',
       'Political_DEBATE_DeBERTa_large_v1.1': 'DEBATE Large',
       'Political_DEBATE_ModernBERT_large_v1.0': 'DEBATE Large (MB)',
       'Meta-Llama-3.1-8B-Instruct': 'Llama 3.1 8B'})

# Hardware display names. The old mapping listed 't4' twice ('Tesla T4 GPU',
# then 'T4'). In a dict literal the later value wins, so only 'T4' was ever
# applied, and 'T4' is the label `hardwareorder` expects.
timing.replace({'mps': 'Apple M3', 't4': 'T4', 'cuda': 'RTX 3090', 'cpu': 'Ryzen 9900x'}, inplace = True)

# Zero-valued placeholder so the Llama group has a Ryzen 9900x slot.
# NOTE(review): assumes the timing CSVs have exactly six columns, with Model and
# Hardware first - confirm against the CSV schema.
timing.loc[len(timing)] = ['Llama 3.1 8B', 'Ryzen 9900x', 0, 0, 0, 0]

sns.set_palette("colorblind")

hardwareorder = ["RTX 3090", "Apple M3", "T4", "Ryzen 9900x"]

models = ['DEBATE Base (MB)', 'DEBATE Large (MB)', 'DEBATE Base', 'DEBATE Large', 'Llama 3.1 8B']
# Copy so the categorical assignments below modify an independent frame, not a
# slice of `timing`.
data = timing[timing['Model'].isin(models)].copy()
data['Hardware'] = pd.Categorical(data['Hardware'], categories=hardwareorder, ordered=True)
data['Model'] = pd.Categorical(data['Model'], categories=models, ordered=True)

# Grouped bar plot: one group per model, one bar per hardware.
plt.figure(figsize=(14, 6))
barplot = sns.barplot(x='Model', y='DPS', hue='Hardware', data=data)

# Label each bar with its DPS value. NaN and zero-height bars (such as the
# placeholder row above) are left unlabelled.
for bar in barplot.patches:
    bar_height = bar.get_height()
    if pd.isna(bar_height) or bar_height <= 0:
        continue
    barplot.annotate(
        f'{bar_height:.0f}',
        (bar.get_x() + bar.get_width() / 2, bar_height),
        ha='center',
        va='bottom',
        fontsize=16,
        fontweight='bold'
    )

plt.xlabel('')

# Bold, two-line tick labels. Positions follow the category order in `models`.
custom_labels = ['DEBATE Base\n(MB)', 'DEBATE Large\n(MB)', 'DEBATE Base\n(DeBERTa)', 'DEBATE Large\n(DeBERTa)', 'Llama 3.1 8B']
plt.xticks(ticks=range(len(custom_labels)), labels=custom_labels, fontweight='bold', fontsize=20)

plt.ylabel('Documents Per Second', fontweight='bold', fontsize = 24)
# Keep the legend but drop its title.
plt.legend(title='')

sns.despine()  # Remove top and right spines
plt.grid(False)

# Save the figure.
plt.tight_layout()
plt.savefig('./figures/figure_8.png', dpi=300)